
# -*- coding:utf-8 -*-
import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet,U2NETP,U2NET
from utils.utils import plot_img_and_mask

from utils import config
import sys
import codecs
# Replace stdout with a UTF-8 StreamWriter wrapped around the underlying binary
# buffer, presumably so that printing non-ASCII text works regardless of the
# console's default encoding.
# NOTE(review): this runs at import time and detaches the original sys.stdout,
# so merely importing this module changes stdout for the importer as well, and
# it fails if sys.stdout has no .detach() (e.g. some IDE/notebook consoles).
# On Python 3.7+ `sys.stdout.reconfigure(encoding="utf-8")` is a less invasive
# alternative worth considering.
sys.stdout = codecs.getwriter("utf-8")(sys.stdout.detach())

def predict_img(net,
                full_img,
                device,
                config=config,
                use_net=None,
                out_threshold=0.5):
    """Run ``net`` on one PIL image and return a mask at the image's full size.

    The image is preprocessed with
    ``BasicDataset.preprocess(full_img, config, use_net, is_mask=False)``,
    pushed through the network as a batch of one, turned into per-class
    probabilities (softmax over channels for multi-class nets, sigmoid for a
    single-class net) and bilinearly resized back to ``full_img.size``.

    Args:
        net: segmentation model exposing an ``n_classes`` attribute; it is
            switched to eval mode here.
        full_img: PIL image; ``full_img.size`` (width, height) defines the
            output resolution.
        device: torch device the input tensor is moved to (should match the
            device ``net`` lives on).
        config: project configuration forwarded to ``BasicDataset.preprocess``.
        use_net: network name forwarded to ``BasicDataset.preprocess``.
        out_threshold: probability cut-off, used only when
            ``net.n_classes == 1``.

    Returns:
        numpy.ndarray: for ``n_classes == 1`` a boolean array of shape (H, W);
        otherwise an int64 one-hot array of shape (n_classes, H, W).
    """
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, config, use_net, is_mask=False))
    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)

    # PIL reports (width, height); F.interpolate expects (height, width).
    full_size = (full_img.size[1], full_img.size[0])

    with torch.no_grad():
        output = net(img)  # e.g. (1, n_classes, 320, 600)
        if output.dim() == 3:
            # NOTE(review): defensive - treat a (N, H, W) output as a single
            # channel so bilinear interpolation (which needs 4-D input) works.
            output = output.unsqueeze(1)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        # Resize the float probabilities directly. The previous
        # ToPILImage -> Resize -> ToTensor round trip quantized them to 8 bits
        # (creating argmax ties), only accepted 1-4 channels, and its trailing
        # squeeze() collapsed any height/width of size 1.
        probs = F.interpolate(probs, size=full_size, mode='bilinear', align_corners=False)
        probs = probs[0].cpu()  # (n_classes, H, W)

    if net.n_classes == 1:
        return (probs[0] > out_threshold).numpy()
    return F.one_hot(probs.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()


def get_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace with ``net``, ``model``, ``input`` (list of paths),
        ``output`` (list of paths or None), ``viz``, ``no_save``,
        ``mask_threshold`` and ``scale``.
    """
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--net', type=str, default="U2NET", help='UNet,U2NET,U2NETP')
    parser.add_argument('--model', '-m', default='checkpoints/U2NET/checkpoint_epoch200.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
    # The metavar used to read 'INPUT', which made the --help text misleading.
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    # NOTE(review): --scale is parsed but the __main__ block does not pass it to
    # predict_img (the scale_factor argument there is commented out).
    parser.add_argument('--scale', '-s', type=float, default=0.6,
                        help='Scale factor for the input images')

    return parser.parse_args()


def get_output_filenames(args):
    """Return the output mask paths that correspond to ``args.input``.

    If ``args.output`` is a non-empty list it is returned as-is. Otherwise
    ``_OUT`` is inserted before each input file's extension, e.g.
    ``dir/img.png`` -> ``dir/img_OUT.png``.
    """
    def _generate_name(fn):
        root, ext = os.path.splitext(fn)
        return f'{root}_OUT{ext}'

    # Leftover debug prints of the splitext result were removed; they wrote
    # three lines to stdout for every input file.
    return args.output or [_generate_name(fn) for fn in args.input]


def mask_to_image(mask: np.ndarray):
    """Convert a predicted mask into an 8-bit single-channel PIL image.

    Args:
        mask: either a 2-D array (such as the boolean mask produced for a
            single-class net), which is scaled by 255, or a 3-D
            (n_classes, H, W) array whose per-pixel argmax class index ``k``
            is written as ``k * 255 / n_classes``.

    Raises:
        ValueError: if ``mask`` is neither 2-D nor 3-D.
    """
    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))
    if mask.ndim == 3:
        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
    # Fail loudly instead of returning None, which only surfaced later as an
    # AttributeError on result.save().
    raise ValueError(f'Expected a 2-D or 3-D mask, got shape {mask.shape}')

# Example:
#   python predict.py -o output/output1.png -i Datasample/2021-10-17_14_45_34.png
if __name__ == '__main__':
    import time

    # Without this the root logger stays at WARNING and every logging.info()
    # call below is silently dropped.
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    t1 = time.perf_counter()
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)
    # Fail before loading the model rather than with an IndexError mid-run.
    if not args.no_save and len(out_files) < len(in_files):
        raise SystemExit(f'Got {len(in_files)} input file(s) but only '
                         f'{len(out_files)} output filename(s)')

    # Any name other than U2NET / U2NETP falls back to the plain UNet.
    net_classes = {'U2NET': U2NET, 'U2NETP': U2NETP}
    net = net_classes.get(args.net, UNet)(n_channels=3, n_classes=3)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info('Model loaded!')
    t2 = time.perf_counter()
    for i, filename in enumerate(in_files):
        logging.info(f'\nPredicting image {filename} ...')
        # Context manager so the file handle is released after each image.
        with Image.open(filename) as img:
            # NOTE(review): args.scale is not forwarded; preprocessing reads
            # its settings from `config` instead.
            mask = predict_img(net=net,
                               full_img=img,
                               config=config,
                               use_net=args.net,
                               out_threshold=args.mask_threshold,
                               device=device)  # e.g. (3, 537, 1000) one-hot

            if not args.no_save:
                out_filename = out_files[i]
                mask_to_image(mask).save(out_filename)
                logging.info(f'Mask saved to {out_filename}')

            if args.viz:
                logging.info(f'Visualizing results for image {filename}, close to continue...')
                plot_img_and_mask(img, mask)
    t3 = time.perf_counter()
    print(f'Setup time (args + model load): {t2 - t1:.3f}s')
    print(f'Total time: {t3 - t1:.3f}s')